import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
from torch import optim
from torch.utils.data import DataLoader
from data import CustomDataset
from matching_network import MatchingNetwork, matching_loss, compute_accuracy, process_batch_matching
from Test import PrototypicalNetwork, prototypical_loss, compute_prototypes, process_batch
import matplotlib.pyplot as plt
from tqdm import tqdm
import time


def set_seed(seed):
    """Seed every random number generator used by the experiment.

    Seeds PyTorch (global generator, current CUDA device and all CUDA
    devices), NumPy's global generator and Python's ``random`` module, then
    puts cuDNN into deterministic, non-autotuning mode so runs are repeatable.

    Args:
        seed: Integer seed applied to all generators.
    """
    # PyTorch generators, in the same order as before: global, current CUDA device, all CUDA devices.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # NumPy and standard-library generators.
    np.random.seed(seed)
    random.seed(seed)

    # cuDNN benchmarking may select different (non-deterministic) kernels between runs.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


def train_prototypical_network(train_loader, test_loader, device, num_classes, num_support, num_epochs=50,
                               *, input_size=8, hidden_size=256, output_size=64, lr=0.001):
    """Train a Prototypical Network and evaluate it on ``test_loader`` after every epoch.

    Args:
        train_loader: Iterable of training batches; each batch is handed to ``process_batch``.
        test_loader: Iterable of evaluation batches in the same format as ``train_loader``.
        device: torch device the model is moved to (also forwarded to ``process_batch``).
        num_classes: Forwarded to ``process_batch``.
        num_support: Forwarded to ``process_batch``.
        num_epochs: Number of passes over ``train_loader``.
        input_size: Input dimension passed to ``PrototypicalNetwork`` (default 8).
        hidden_size: Hidden dimension passed to ``PrototypicalNetwork`` (default 256).
        output_size: Embedding dimension passed to ``PrototypicalNetwork`` (default 64).
        lr: Adam learning rate (default 0.001).

    Returns:
        Tuple ``(model, train_losses, test_accuracies)``: the trained model, the
        mean training loss per epoch, and the test accuracy (in percent) per epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches or ``test_loader`` yields
            no query samples (previously surfaced as a bare ZeroDivisionError).
    """
    print("训练 Prototypical Network...")

    proto_net = PrototypicalNetwork(input_size, hidden_size, output_size)
    proto_net.to(device)
    optimizer = optim.Adam(proto_net.parameters(), lr=lr)

    train_losses = []
    test_accuracies = []

    for epoch in range(num_epochs):
        # ---- training pass ----
        proto_net.train()
        epoch_loss = 0.0
        batch_count = 0

        for batch in tqdm(train_loader, desc=f'Proto Epoch {epoch + 1}'):
            prototypes, query_embeddings, query_labels = process_batch(
                proto_net, batch, device, num_classes, num_support
            )
            loss = prototypical_loss(prototypes, query_embeddings, query_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        if batch_count == 0:
            raise ValueError("train_loader yielded no batches")
        avg_loss = epoch_loss / batch_count
        train_losses.append(avg_loss)

        # ---- evaluation pass ----
        proto_net.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                prototypes, query_embeddings, query_labels = process_batch(
                    proto_net, batch, device, num_classes, num_support
                )
                # Predict the prototype with the smallest Euclidean distance.
                # NOTE(review): assumes query_labels index the rows of `prototypes`
                # in the order produced by process_batch — confirm.
                predicted = torch.cdist(query_embeddings, prototypes).argmin(dim=1)

                total += query_labels.size(0)
                correct += (predicted == query_labels).sum().item()

        if total == 0:
            raise ValueError("test_loader yielded no query samples")
        accuracy = 100 * correct / total
        test_accuracies.append(accuracy)

        # Log every 10 epochs, and always log the final epoch.
        if epoch % 10 == 0 or epoch == num_epochs - 1:
            print(f'Proto Epoch {epoch + 1}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%')

    return proto_net, train_losses, test_accuracies


def train_matching_network(train_loader, test_loader, device, num_classes, num_support, num_epochs=50,
                           *, input_size=8, hidden_size=256, output_size=64, lr=0.001):
    """Train a Matching Network and evaluate it on ``test_loader`` after every epoch.

    Args:
        train_loader: Iterable of training batches; each batch is handed to ``process_batch_matching``.
        test_loader: Iterable of evaluation batches in the same format as ``train_loader``.
        device: torch device the model is moved to (also forwarded to ``process_batch_matching``).
        num_classes: Forwarded to ``process_batch_matching``.
        num_support: Forwarded to ``process_batch_matching``.
        num_epochs: Number of passes over ``train_loader``.
        input_size: Input dimension passed to ``MatchingNetwork`` (default 8).
        hidden_size: Hidden dimension passed to ``MatchingNetwork`` (default 256).
        output_size: Embedding dimension passed to ``MatchingNetwork`` (default 64).
        lr: Adam learning rate (default 0.001).

    Returns:
        Tuple ``(model, train_losses, test_accuracies)``: the trained model, the
        mean training loss per epoch, and the test accuracy (in percent) per epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches or ``test_loader`` yields
            no query samples (previously surfaced as a bare ZeroDivisionError).
    """
    print("训练 Matching Network...")

    matching_net = MatchingNetwork(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        use_fce=True,
        attention_type='cosine'
    )
    matching_net.to(device)
    optimizer = optim.Adam(matching_net.parameters(), lr=lr)

    train_losses = []
    test_accuracies = []

    for epoch in range(num_epochs):
        # ---- training pass ----
        matching_net.train()
        epoch_loss = 0.0
        batch_count = 0

        for batch in tqdm(train_loader, desc=f'Match Epoch {epoch + 1}'):
            predictions, query_labels = process_batch_matching(
                matching_net, batch, device, num_classes, num_support
            )
            loss = matching_loss(predictions, query_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        if batch_count == 0:
            raise ValueError("train_loader yielded no batches")
        avg_loss = epoch_loss / batch_count
        train_losses.append(avg_loss)

        # ---- evaluation pass ----
        matching_net.eval()
        weighted_accuracy = 0.0
        total_queries = 0

        with torch.no_grad():
            for batch in test_loader:
                predictions, query_labels = process_batch_matching(
                    matching_net, batch, device, num_classes, num_support
                )
                # Weight each batch by its query count so a short final batch does not
                # count as much as a full one. This makes the metric correct/total, the
                # same definition used for the Prototypical Network.
                # NOTE(review): assumes compute_accuracy returns the fraction (0..1) of
                # correctly classified queries in the batch (implied by the *100 below) — confirm.
                n_queries = len(query_labels)
                weighted_accuracy += float(compute_accuracy(predictions, query_labels)) * n_queries
                total_queries += n_queries

        if total_queries == 0:
            raise ValueError("test_loader yielded no query samples")
        avg_accuracy = (weighted_accuracy / total_queries) * 100
        test_accuracies.append(avg_accuracy)

        # Log every 10 epochs, and always log the final epoch.
        if epoch % 10 == 0 or epoch == num_epochs - 1:
            print(f'Match Epoch {epoch + 1}: Loss={avg_loss:.4f}, Accuracy={avg_accuracy:.2f}%')

    return matching_net, train_losses, test_accuracies


def plot_comparison(proto_losses, proto_accuracies, match_losses, match_accuracies,
                    save_path='networks_comparison.png'):
    """Plot training loss and test accuracy of both networks side by side.

    Each series is drawn against its own 1-based epoch axis, so the two networks
    may have been trained for different numbers of epochs. The figure is saved to
    ``save_path`` (300 dpi), shown, and then closed to release its memory.

    Args:
        proto_losses: Per-epoch training losses of the Prototypical Network.
        proto_accuracies: Per-epoch test accuracies (%) of the Prototypical Network.
        match_losses: Per-epoch training losses of the Matching Network.
        match_accuracies: Per-epoch test accuracies (%) of the Matching Network.
        save_path: Output image path (default ``'networks_comparison.png'``).
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # NOTE(review): the titles are Chinese; Matplotlib's default font may lack CJK
    # glyphs and render them as boxes unless a CJK-capable font is configured.
    panels = (
        (loss_ax, proto_losses, match_losses, 'Training Loss', '训练损失对比'),
        (acc_ax, proto_accuracies, match_accuracies, 'Test Accuracy (%)', '测试准确率对比'),
    )
    for ax, proto_series, match_series, ylabel, title in panels:
        ax.plot(range(1, len(proto_series) + 1), proto_series, 'b-',
                label='Prototypical Network', linewidth=2)
        ax.plot(range(1, len(match_series) + 1), match_series, 'r-',
                label='Matching Network', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)


def main():
    """Train both few-shot networks on the same data split and compare them.

    Loads ``traindata_5dB.mat``, trains a Prototypical Network and a Matching
    Network on identical loaders, prints a summary, and writes the two model
    state dicts, a comparison plot and a results dict to the working directory.
    """
    # Seed all RNGs; the same seed also drives the train/test split below.
    seed = 42
    set_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Dataset / episode parameters.
    num_classes = 10
    num_support = 10
    num_query = 15
    num_epochs = 20  # kept small for quick experiments

    print("加载数据...")
    dictdata = loadmat("traindata_5dB.mat")
    dataload = dictdata['traindata']

    dataset = CustomDataset(dataload, noisy=False)
    # NOTE(review): test_size=0.9 trains on only 10% of the samples — confirm this is intended.
    train_dataset, test_dataset = train_test_split(dataset, test_size=0.9, random_state=seed)

    # Each batch holds num_classes * (num_support + num_query) samples; presumably the
    # batch processors split it into support and query sets — confirm.
    # NOTE(review): without drop_last=True the final batch can be smaller than this;
    # confirm process_batch / process_batch_matching tolerate a partial batch.
    batch_size = num_classes * (num_support + num_query)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    print(f"训练集大小: {len(train_dataset)}")
    print(f"测试集大小: {len(test_dataset)}")
    print(f"批次大小: {batch_size}")

    # perf_counter is a monotonic clock suited to measuring elapsed time.
    proto_start_time = time.perf_counter()
    proto_net, proto_losses, proto_accuracies = train_prototypical_network(
        train_loader, test_loader, device, num_classes, num_support, num_epochs
    )
    proto_time = time.perf_counter() - proto_start_time

    match_start_time = time.perf_counter()
    matching_net, match_losses, match_accuracies = train_matching_network(
        train_loader, test_loader, device, num_classes, num_support, num_epochs
    )
    match_time = time.perf_counter() - match_start_time

    proto_best = max(proto_accuracies)
    match_best = max(match_accuracies)

    # Summary ("\n" is a real newline here; the previous "\\n" printed a literal backslash-n).
    print("\n" + "=" * 60)
    print("训练结果总结")
    print("=" * 60)
    print("Prototypical Network:")
    print(f"  最终训练损失: {proto_losses[-1]:.4f}")
    print(f"  最终测试准确率: {proto_accuracies[-1]:.2f}%")
    print(f"  最佳测试准确率: {proto_best:.2f}%")
    print(f"  训练时间: {proto_time:.2f} 秒")

    print("\nMatching Network:")
    print(f"  最终训练损失: {match_losses[-1]:.4f}")
    print(f"  最终测试准确率: {match_accuracies[-1]:.2f}%")
    print(f"  最佳测试准确率: {match_best:.2f}%")
    print(f"  训练时间: {match_time:.2f} 秒")

    torch.save(proto_net.state_dict(), 'comparison_prototypical_network.pth')
    torch.save(matching_net.state_dict(), 'comparison_matching_network.pth')

    plot_comparison(proto_losses, proto_accuracies, match_losses, match_accuracies)

    results = {
        'prototypical': {
            'losses': proto_losses,
            'accuracies': proto_accuracies,
            'best_accuracy': proto_best,
            'final_accuracy': proto_accuracies[-1],
            'training_time': proto_time
        },
        'matching': {
            'losses': match_losses,
            'accuracies': match_accuracies,
            'best_accuracy': match_best,
            'final_accuracy': match_accuracies[-1],
            'training_time': match_time
        }
    }
    torch.save(results, 'comparison_results.pth')

    print("\n对比结果已保存到: comparison_results.pth")
    print("对比图表已保存到: networks_comparison.png")
    print("对比实验完成！")


if __name__ == '__main__':
    # Run the comparison experiment only when this file is executed as a script.
    main()
